package model

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"math"

	"github.com/Masterminds/squirrel"
	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/sqlc"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

// AccountModel is the account-table model exposed to the rest of the service.
// It embeds the generated accountModel and adds hand-written helpers; add new
// methods here and implement them on the model types below.
type AccountModel interface {
	accountModel

	// RowBuilder returns a squirrel SELECT over all account columns.
	RowBuilder() squirrel.SelectBuilder
	// Truncate clears the account table; it is best-effort and reports no error.
	Truncate(ctx context.Context)
	// FindAll returns every account row, read without going through the cache.
	FindAll(ctx context.Context) ([]*Account, error)

	// The four methods below update numeric column(s) of row id; count is
	// scaled by 10^precision before it is written.

	// AddField adds count to column field.
	AddField(ctx context.Context, id int64, field string, count int64, precision uint8) error
	// DeductField subtracts count from column field only if field >= count.
	DeductField(ctx context.Context, id int64, field string, count int64, precision uint8) error
	// AdjustTrading adds count to column <field>_trading, guarded so that
	// <field>_trading + count + <field> stays non-negative.
	AdjustTrading(ctx context.Context, id int64, field string, count int64, precision uint8) error
	// AdjustBalance moves count from column <field>_trading into column <field>.
	AdjustBalance(ctx context.Context, id int64, field string, count int64, precision uint8) error

	// FindAllEx returns page Current (1-based) holding up to PageSize accounts.
	FindAllEx(Current int64, PageSize int64) (*[]Account, error)
	// Count returns the number of account rows.
	Count() (int64, error)
}

// customAccountModel is the concrete AccountModel. Its methods are promoted
// from the embedded *defaultAccountModel.
type customAccountModel struct {
	*defaultAccountModel
}

// Compile-time check that *customAccountModel satisfies AccountModel.
var _ AccountModel = (*customAccountModel)(nil)

// NewAccountModel builds an AccountModel for the account table on top of the
// given SQL connection and cache configuration.
func NewAccountModel(conn sqlx.SqlConn, c cache.CacheConf) AccountModel {
	base := newAccountModel(conn, c)
	return &customAccountModel{defaultAccountModel: base}
}

// RowBuilder returns a squirrel SELECT of every account column (accountRows)
// from this model's table, which callers can refine before calling ToSql.
func (m *defaultAccountModel) RowBuilder() squirrel.SelectBuilder {
	columns := accountRows
	return squirrel.Select(columns).From(m.table)
}

// Truncate empties the account table.
//
// It first loads every row (bypassing the cache) and removes each one through
// Delete, then issues TRUNCATE on the table. Routing through Delete presumably
// invalidates each row's cache entries (Delete is defined elsewhere — confirm).
// If the rows cannot be listed, nothing is deleted or truncated.
//
// AccountModel gives this method no error return, so failures are
// deliberately ignored and the operation is best-effort.
func (m *defaultAccountModel) Truncate(ctx context.Context) {
	query, args, err := m.RowBuilder().ToSql()
	if err != nil {
		return
	}

	var rows []*Account
	if err := m.QueryRowsNoCacheCtx(ctx, &rows, query, args...); err != nil {
		// Without the row list the per-row Deletes cannot run, so skip the
		// TRUNCATE as well.
		return
	}

	for _, row := range rows {
		// Best-effort: keep going so one failing row does not block the rest.
		_ = m.Delete(ctx, row.Id)
	}
	// Best-effort: there is no error return through which to report failure.
	_, _ = m.ExecNoCacheCtx(ctx, fmt.Sprintf("TRUNCATE %s", m.tableName()))
}

// FindAll returns every row of the account table, read directly from the
// database without consulting the cache.
//
// It returns an error if the SELECT cannot be built or the query fails.
func (m *defaultAccountModel) FindAll(ctx context.Context) ([]*Account, error) {
	query, args, err := m.RowBuilder().ToSql()
	if err != nil {
		// Previously this returned (nil, nil), which made a build failure
		// indistinguishable from an empty table.
		return nil, fmt.Errorf("building account select: %w", err)
	}

	var resp []*Account
	if err := m.QueryRowsNoCacheCtx(ctx, &resp, query, args...); err != nil {
		return nil, err
	}
	return resp, nil
}

// AddField adds count, scaled to count*10^precision, to column field of the
// account row with the given id. The row's id cache key is passed to ExecCtx
// for invalidation.
//
// It returns an error if:
//   - the scaled amount overflows int64,
//   - the UPDATE fails,
//   - or no row was affected (e.g. id does not exist).
//
// NOTE(review): with MySQL's default affected-rows semantics, a zero amount
// changes nothing and is also reported as "更新失败" — confirm that is intended.
//
// NOTE(review): field is spliced into the SQL text as a column name, because
// identifiers cannot be bound as placeholders. Callers must pass only trusted
// column names, never raw user input.
func (m *defaultAccountModel) AddField(ctx context.Context, id int64, field string, count int64, precision uint8) error {
	accountIdKey := fmt.Sprintf("%s%v", cacheAccountIdPrefix, id)

	// Scale with checked integer arithmetic so an out-of-range amount is
	// rejected instead of silently wrapping.
	amount := count
	for i := uint8(0); i < precision; i++ {
		if amount > math.MaxInt64/10 || amount < math.MinInt64/10 {
			return fmt.Errorf("account %d: %s amount %d at precision %d overflows int64", id, field, count, precision)
		}
		amount *= 10
	}

	result, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf("update %s set %s=%s+? where `id`=? ", m.table, field, field)
		return conn.ExecCtx(ctx, query, amount, id)
	}, accountIdKey)
	if err != nil {
		return err
	}

	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if n <= 0 {
		return errors.New("更新失败")
	}
	return nil
}

// DeductField subtracts count, scaled to count*10^precision, from column field
// of the account row with the given id. The UPDATE only matches while
// field >= the scaled amount, so the column is not driven below it. The row's
// id cache key is passed to ExecCtx for invalidation.
//
// It returns an error if:
//   - the scaled amount overflows int64,
//   - the UPDATE fails,
//   - or no row was affected (id missing or insufficient value).
//
// NOTE(review): field is spliced into the SQL text as a column name, because
// identifiers cannot be bound as placeholders. Callers must pass only trusted
// column names, never raw user input.
func (m *defaultAccountModel) DeductField(ctx context.Context, id int64, field string, count int64, precision uint8) error {
	accountIdKey := fmt.Sprintf("%s%v", cacheAccountIdPrefix, id)

	// Scale with checked integer arithmetic so an out-of-range amount is
	// rejected instead of silently wrapping.
	amount := count
	for i := uint8(0); i < precision; i++ {
		if amount > math.MaxInt64/10 || amount < math.MinInt64/10 {
			return fmt.Errorf("account %d: %s amount %d at precision %d overflows int64", id, field, count, precision)
		}
		amount *= 10
	}

	result, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf("update %s set %s=%s-? where `id` = ? and %s>=?", m.table, field, field, field)
		return conn.ExecCtx(ctx, query, amount, id, amount)
	}, accountIdKey)
	if err != nil {
		return err
	}

	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if n <= 0 {
		return errors.New("更新失败")
	}
	return nil
}

// AdjustTrading adds count, scaled to count*10^precision, to column
// <field>_trading of the account row with the given id. The UPDATE only
// matches while <field>_trading + amount + <field> >= 0. The row's id cache
// key is passed to ExecCtx for invalidation.
//
// It returns an error if:
//   - the scaled amount overflows int64,
//   - the UPDATE fails,
//   - or no row was affected (id missing or the guard rejected the change).
//
// NOTE(review): field is spliced into the SQL text as part of column names.
// Callers must pass only trusted column-name prefixes, never raw user input.
func (m *defaultAccountModel) AdjustTrading(ctx context.Context, id int64, field string, count int64, precision uint8) error {
	accountIdKey := fmt.Sprintf("%s%v", cacheAccountIdPrefix, id)

	// Scale with checked integer arithmetic so an out-of-range amount is
	// rejected instead of silently wrapping.
	amount := count
	for i := uint8(0); i < precision; i++ {
		if amount > math.MaxInt64/10 || amount < math.MinInt64/10 {
			return fmt.Errorf("account %d: %s amount %d at precision %d overflows int64", id, field, count, precision)
		}
		amount *= 10
	}

	result, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf(" update %s set %s_trading=%s_trading+1*? where id=? and %s_trading + 1*? + %s >= 0 ", m.table, field, field, field, field)
		return conn.ExecCtx(ctx, query, amount, id, amount)
	}, accountIdKey)
	if err != nil {
		return err
	}

	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if n <= 0 {
		return errors.New("更新失败")
	}
	return nil
}

// AdjustBalance moves count, scaled to count*10^precision, from column
// <field>_trading into column <field> of the account row with the given id.
// It does this in a single UPDATE with no non-negativity guard. The row's id
// cache key is passed to ExecCtx for invalidation.
//
// It returns an error if:
//   - the scaled amount overflows int64,
//   - the UPDATE fails,
//   - or no row was affected (e.g. id does not exist).
//
// NOTE(review): field is spliced into the SQL text as part of column names.
// Callers must pass only trusted column-name prefixes, never raw user input.
func (m *defaultAccountModel) AdjustBalance(ctx context.Context, id int64, field string, count int64, precision uint8) error {
	accountIdKey := fmt.Sprintf("%s%v", cacheAccountIdPrefix, id)

	// Scale with checked integer arithmetic so an out-of-range amount is
	// rejected instead of silently wrapping.
	amount := count
	for i := uint8(0); i < precision; i++ {
		if amount > math.MaxInt64/10 || amount < math.MinInt64/10 {
			return fmt.Errorf("account %d: %s amount %d at precision %d overflows int64", id, field, count, precision)
		}
		amount *= 10
	}

	result, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf("update %s set %s_trading=%s_trading-1*?, %s=%s+1*? where id=? ", m.table, field, field, field, field)
		return conn.ExecCtx(ctx, query, amount, amount, id)
	}, accountIdKey)
	if err != nil {
		return err
	}

	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if n <= 0 {
		return errors.New("更新失败")
	}
	return nil
}

// FindAllEx returns one page of accounts read directly from the database,
// bypassing the cache.
//
// Current is the 1-based page number; values below 1 are treated as page 1
// instead of producing a negative OFFSET. PageSize is the maximum number of
// rows in the page. Rows are ordered by `id` so consecutive pages are stable
// and do not overlap.
//
// It returns ErrNotFound when the driver reports sqlc.ErrNotFound, and any
// other query error unchanged.
func (m *defaultAccountModel) FindAllEx(Current int64, PageSize int64) (*[]Account, error) {
	if Current < 1 {
		Current = 1
	}

	// Use the model's table and column list rather than a hard-coded
	// "select * from account".
	query := fmt.Sprintf("select %s from %s order by `id` limit ?,?", accountRows, m.table)

	var resp []Account
	err := m.CachedConn.QueryRowsNoCache(&resp, query, (Current-1)*PageSize, PageSize)
	switch {
	case err == nil:
		return &resp, nil
	case errors.Is(err, sqlc.ErrNotFound):
		return nil, ErrNotFound
	default:
		return nil, err
	}
}

// Count returns the total number of rows in the account table, read directly
// from the database without consulting the cache. A sqlc.ErrNotFound from the
// driver is reported as ErrNotFound; any other error is returned as-is.
func (m *defaultAccountModel) Count() (int64, error) {
	var total int64
	stmt := fmt.Sprintf("select count(*) as count from %s", m.table)

	if err := m.CachedConn.QueryRowNoCache(&total, stmt); err != nil {
		if err == sqlc.ErrNotFound {
			return 0, ErrNotFound
		}
		return 0, err
	}
	return total, nil
}
